#!/usr/bin/env python3

import argparse
import os

import lightning as L
from lightning.pytorch.loggers import WandbLogger, CSVLogger
from lightning.pytorch.callbacks import ModelCheckpoint

from modern_hopfield_attention.lightning_models import LitViT
from modern_hopfield_attention.data import VIT_MODEL_SIZE, NUM_CLASSES

# args
# TODO move to data/arguments.py


def _str2bool(value: str) -> bool:
    """Convert a command-line string such as ``"true"`` or ``"0"`` to a bool.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy.

    Args:
        value: Raw string taken from the command line.

    Returns:
        The boolean the string spells out (case-insensitive).

    Raises:
        argparse.ArgumentTypeError: If *value* is not a recognised spelling.
    """
    lowered = value.strip().lower()
    if lowered in {"true", "t", "yes", "y", "1", "on"}:
        return True
    # The empty string stays falsy, exactly as it was under ``type=bool``.
    if lowered in {"false", "f", "no", "n", "0", "off", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


class _StoreTuple(argparse.Action):
    """Store the values collected for an option as a tuple instead of a list.

    Keeps a multi-value option the same type as its tuple default.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        setattr(namespace, self.dest, tuple(values))


parser = argparse.ArgumentParser()

# Options are registered on the group objects (not on the bare parser) so that
# the group titles actually show up in ``--help``.
_main_group = parser.add_argument_group("Main Config")
_main_group.add_argument(
    "--dataset_type",
    type=str,
    choices=["torch/cifar10", "torch/cifar100", "torch/imagenet"],
    required=True,
)
_main_group.add_argument(
    "--universal",
    action="store_true",
)
_main_group.add_argument(
    "--model_size",
    type=str,
    choices=["tiny", "small", "base", "large", "huge", "custom"],
    required=True,
)
_main_group.add_argument(
    "--run_name",
    type=str,
    required=True,
)
# result directory
_main_group.add_argument(
    "--save_dir",
    type=str,
    required=True,
)
# parser.add_argument(
#     "--wo_mlp",
#     action="store_true",
# )
_main_group.add_argument(
    "--devices",
    type=int,
    required=True,
)

_ablation_group = parser.add_argument_group("Ablation model Config")
_ablation_group.add_argument(
    "--depth",
    type=int,
)
_ablation_group.add_argument(
    "--num_heads",
    type=int,
    default=3,  # same as the tiny setting
)
_ablation_group.add_argument(
    "--embed_dim",
    type=int,
    default=192,  # same as the tiny setting
)
_ablation_group.add_argument(
    "--batch_size",
    type=int,
    default=256,  # same as the tiny setting
)
_ablation_group.add_argument(
    "--drop_path_rate",
    type=float,  # a rate is fractional; ``int`` rejected values such as 0.2
    default=0.1,
)
_ablation_group.add_argument(
    "--learning_rate",
    type=float,
    default=2.5e-4,  # same as the tiny setting
)
_ablation_group.add_argument(
    "--epochs",
    type=int,
    default=300,  # same as the main experiment
)

_optional_group = parser.add_argument_group("Optional Config")
# model config
_optional_group.add_argument(
    "--img_size",
    type=int,
    default=224,
)
_optional_group.add_argument(
    "--in_chans",
    type=int,
    default=3,
)
_optional_group.add_argument(
    "--patch_size",
    type=int,
    default=16,
)
_optional_group.add_argument(
    "--class_token",
    type=_str2bool,  # accepts true/false, yes/no, 1/0, on/off
    default=True,
)
# augmentation config
## mixup & cutmix
_optional_group.add_argument(
    "--mixup",
    type=float,
    default=0.8,
)
_optional_group.add_argument(
    "--cutmix",
    type=float,
    default=1.0,
)
_optional_group.add_argument(
    "--label_smoothing",
    type=float,
    default=0.1,
)
## random_erasing
_optional_group.add_argument(
    "--re_prob",
    type=float,
    default=0.25,
)
_optional_group.add_argument(
    "--re_mode",
    type=str,
    default="const",
    choices=["const", "rand", "pixel"],
)
_optional_group.add_argument(
    "--re_count",
    type=int,
    default=1,
)
_optional_group.add_argument(
    "--re_split",
    type=_str2bool,  # accepts true/false, yes/no, 1/0, on/off
    default=False,
)
## crop
_optional_group.add_argument(
    "--crop_pct",
    type=float,
    default=0.875,
)
## auto augment
_optional_group.add_argument(
    "--auto_augment",
    type=str,
    default="rand-m9-mstd0.5",
)

# optimizer & scheduler
_optional_group.add_argument(
    "--eps",
    type=float,
    default=1e-8,
)
# Usage: ``--betas 0.9 0.999``. A generic alias such as ``tuple[float, float]``
# is not a usable argparse ``type``: it would split the raw string into
# characters. Two floats are collected instead and stored as a tuple.
_optional_group.add_argument(
    "--betas",
    type=float,
    nargs=2,
    action=_StoreTuple,
    metavar=("BETA1", "BETA2"),
    default=(0.9, 0.999),
)
_optional_group.add_argument(
    "--weight_decay",
    type=float,
    default=0.05,
)
_optional_group.add_argument(
    "--lr_min",
    type=float,
    default=1e-6,
)
_optional_group.add_argument(
    "--t_initial",
    type=int,
    default=300,
)
_optional_group.add_argument(
    "--warmup_t",
    type=int,
    default=20,
)
_optional_group.add_argument(
    "--warmup_lr_init",
    type=float,
    default=1e-6,
)
# dataset directory
_optional_group.add_argument(
    "--dataset_dir",
    type=str,
    default="dataset",
)
_optional_group.add_argument(
    "--num_workers",
    type=int,
    default=1,
)
# weights & biases
_optional_group.add_argument(
    "--project_name",
    type=str,
    default="MHA",
)
_optional_group.add_argument(
    "--ckpt_path",
    type=str,
    default=None,
)
_optional_group.add_argument(
    "--version",
    type=str,
    default=None,
)
_optional_group.add_argument(
    "--offline",
    action="store_true",
)

# Parse the process command line once, at module load. ``args=None`` makes
# argparse read ``sys.argv[1:]``, which is also what a bare call does.
# NOTE(review): this runs on import as well as on direct execution.
args: argparse.Namespace = parser.parse_args(args=None)


def main(args: argparse.Namespace) -> None:
    """Train a ``LitViT`` model with Lightning from the parsed CLI options.

    Steps, in order:

    1. Create ``args.save_dir`` when it is non-empty.
    2. Set ``args.num_classes`` from ``NUM_CLASSES[args.dataset_type]``.
    3. Unless ``args.model_size`` is ``"custom"``, overwrite ``num_heads``,
       ``depth``, ``embed_dim``, ``batch_size`` and ``learning_rate`` with the
       matching ``VIT_MODEL_SIZE`` preset.
    4. Build the model, a Weights & Biases logger, a CSV logger and a
       checkpoint callback, then run ``Trainer.fit``.

    Args:
        args: Parsed command-line namespace. It is mutated in place (steps 2
            and 3) before being handed to ``LitViT``.
    """
    if args.save_dir:
        os.makedirs(args.save_dir, exist_ok=True)
    args.num_classes = NUM_CLASSES[args.dataset_type]

    # ablation setting: any non-custom size replaces the CLI values wholesale.
    if args.model_size != "custom":
        preset = VIT_MODEL_SIZE[args.model_size]
        # (namespace attribute, preset key) pairs, applied in this order.
        # NOTE(review): "leaning_rate" is the key spelling this script looks
        # up; confirm it matches the VIT_MODEL_SIZE definition before renaming.
        for attr_name, preset_key in (
            ("num_heads", "num_heads"),
            ("depth", "depth"),
            ("embed_dim", "embedding_dim"),
            ("batch_size", "batch_size"),
            ("learning_rate", "leaning_rate"),
        ):
            setattr(args, attr_name, preset[preset_key])

    lit_model = LitViT(args=args)

    # logger
    # Model artifacts are only logged to W&B when the run is online.
    upload_checkpoints = not args.offline
    loggers = [
        WandbLogger(
            name=args.run_name,
            project=args.project_name,
            version=args.version,
            save_dir=args.save_dir,
            offline=args.offline,
            log_model=upload_checkpoints,
        ),
        CSVLogger(args.save_dir),
    ]

    # checkpoint: track the lowest "valid/loss" and always keep the last one.
    best_ckpt = ModelCheckpoint(
        args.save_dir,
        monitor="valid/loss",
        mode="min",
        save_last=True,
    )

    trainer = L.Trainer(
        devices=args.devices,
        max_epochs=args.epochs,
        logger=loggers,
        callbacks=[best_ckpt],
    )

    # No dataloaders are passed here, so the model presumably supplies its own
    # (TODO confirm in LitViT). ``ckpt_path=None`` is forwarded unchanged.
    trainer.fit(lit_model, ckpt_path=args.ckpt_path)


if __name__ == "__main__":
    # Script entry point: hand the module-level parsed namespace to main().
    main(args)
